import torch
from torch import nn
import torch.nn.functional as F


class Discriminator(object):
    """Feed-forward classifier used as the adversary in ``AdversarialLatentLayer``.

    Builds ``hidden_layers - 1`` blocks of ``Linear -> PReLU -> Dropout``,
    followed by a final ``Linear`` projecting to ``target_classes`` logits.
    When ``hidden_layers <= 1``, only the final projection is created.

    NOTE(review): this is a plain ``object`` rather than an ``nn.Module``, so its
    parameters are not registered on an owning module. They will not appear in
    the owner's ``parameters()``/``state_dict()`` and are not moved by
    ``owner.to(device)``. Their Dropout layers also stay in training mode unless
    ``self.layers.eval()`` is called directly. Presumably this is deliberate, so
    the main model optimizer never updates the adversary; confirm before changing.

    Args:
        input_dim: number of input features.
        hidden_dim: width of each hidden layer.
        hidden_layers: total number of Linear layers (hidden + output, minimum 1).
        dropout: dropout probability applied after each hidden activation.
        target_classes: number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, hidden_layers, dropout, target_classes):
        self.layers = nn.ModuleList()
        for _ in range(hidden_layers - 1):
            self.layers.append(nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.PReLU(),
                nn.Dropout(dropout),
            ))
            # Subsequent layers consume the previous hidden width.
            input_dim = hidden_dim
        self.layers.append(nn.Linear(input_dim, target_classes))

    def __call__(self, x_dict):
        """Return the logits for the given features.

        Args:
            x_dict: either a mapping holding the feature tensor under ``'h'``,
                or the feature tensor itself. ``AdversarialLatentLayer.forward``
                passes a bare tensor, so both forms are supported.

        Returns:
            The output of the final Linear layer (``target_classes`` logits per row).
        """
        h = x_dict if torch.is_tensor(x_dict) else x_dict['h']
        for layer in self.layers:
            h = layer(h)
        return h

class AdversarialLatentLayer(nn.Module):
    """Adversarial regularizer on a subset of latent dimensions.

    A ``Discriminator`` is trained via ``d_iter`` to predict ``x_dict[label_key]``
    from the selected latent columns. ``forward`` returns the *negated*
    discriminator cross-entropy; minimizing it pushes the upstream encoder to
    make those columns uninformative about the label.

    Args:
        input_dims: selector for the latent columns fed to the discriminator.
            It is used as ``h[:, input_dims]``, and its ``.int()`` sum gives the
            discriminator input width. Presumably a boolean mask over the latent
            dimension (TODO confirm).
        label_key: key in ``x_dict`` holding the class targets for cross-entropy.
        discriminator_hidden: hidden width of the discriminator.
        discriminator_layers: total number of Linear layers in the discriminator.
        discriminator_dropout: dropout probability inside the discriminator.
        target_classes: number of label classes.
        disc_lr: learning rate of the discriminator's AdamW optimizer.
        disc_wd: weight decay of the discriminator's AdamW optimizer.
        **kwargs: accepted and ignored.
    """

    def __init__(self, input_dims, label_key, discriminator_hidden, discriminator_layers, discriminator_dropout=0.1,
                 target_classes=2, disc_lr=1e-2, disc_wd=1e-6, **kwargs):
        super().__init__()
        self.source_dims = input_dims
        self.label_keys = label_key
        # Convert to a Python int so nn.Linear receives an int, not a 0-d tensor.
        num_src_dim = int(input_dims.int().sum())
        self.discriminator = Discriminator(num_src_dim, discriminator_hidden, discriminator_layers,
                                           discriminator_dropout, target_classes)
        self.is_adversarial = True
        self.d_optimizer = torch.optim.AdamW(self.discriminator.layers.parameters(), lr=disc_lr, weight_decay=disc_wd)

    def forward(self, x_dict):
        """Return ``(h, adversarial_loss)``.

        ``adversarial_loss`` is the negated cross-entropy of the discriminator on
        the selected latent columns of ``x_dict['h']``. It is not detached, so
        gradients flow back into ``h``.
        """
        h = x_dict['h']
        # Discriminator.__call__ reads its features from the 'h' key of a mapping.
        logits = self.discriminator({'h': h[:, self.source_dims]})
        return h, -F.cross_entropy(logits, x_dict[self.label_keys])

    def d_loss(self, x_dict):
        """Return the discriminator's cross-entropy on the selected latent columns.

        The features are detached, so this loss only yields gradients for the
        discriminator. It does not accumulate into the encoder's ``.grad``, and it
        does not re-traverse a graph that a previous ``backward()`` may have freed.
        """
        h = x_dict['h'][:, self.source_dims].detach()
        logits = self.discriminator({'h': h})
        return F.cross_entropy(logits, x_dict[self.label_keys])

    def set_d_optimizer(self, lr=1e-2, wd=1e-6):
        """Replace the discriminator optimizer with a fresh AdamW using ``lr``/``wd``."""
        self.d_optimizer = torch.optim.AdamW(self.discriminator.layers.parameters(), lr=lr, weight_decay=wd)

    def d_iter(self, x_dict):
        """Run one discriminator update step and return the detached loss."""
        self.d_optimizer.zero_grad()
        loss = self.d_loss(x_dict)
        loss.backward()
        self.d_optimizer.step()
        return loss.detach()
